package org.zjvis.datascience.spark.algorithm;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.mllib.feature.PCA;
import org.apache.spark.mllib.feature.PCAModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel;
import org.zjvis.datascience.spark.util.OptionHelper;
import org.zjvis.datascience.spark.util.OutputResult;
import org.zjvis.datascience.spark.util.UtilTool;
import scala.collection.JavaConversions;

import java.util.*;

/**
 * @description Spark PCA dimensionality-reduction operator
 * @date 2021-12-23
 */
public class PcaAlgorithm extends BaseAlgorithm {

    private String[] featureCols;

    private String targetTable;

    private String sourceTable;

    private String modelPath;

    private int k;


    public PcaAlgorithm(SparkSession sparkSession) {
        super(sparkSession);
    }

    public boolean parseParams(String[] args) {
        super.parseParams(args);
        options.addOption(OptionHelper.getSpecOption("source"));
        options.addOption(OptionHelper.getSpecOption("target"));
        options.addOption(OptionHelper.getSpecOption("featureCols"));
        options.addOption(
                OptionHelper.getSpecOption("k", "keyComponent", "principal components", true));

        boolean flag = this.initCommandLine(args);

        if (!flag) {
            logger.error("initCommandLine fail!!!");
            return false;
        }

        if (cmd.hasOption("f") || cmd.hasOption("featureCols")) {
            featureCols = cmd.getOptionValue("featureCols").split(",");
        }
        if (cmd.hasOption("s") || cmd.hasOption("source")) {
            sourceTable = cmd.getOptionValue("source");
        }
        if (cmd.hasOption("t") || cmd.hasOption("target")) {
            targetTable = cmd.getOptionValue("target");
        }
        if (cmd.hasOption("k") || cmd.hasOption("keyComponent")) {
            k = Integer.parseInt(cmd.getOptionValue("keyComponent"));
        }
        return true;
    }

    public boolean beginAlgorithm() {
        Dataset<Row> dataset = null;
        if (isHive) {
            String sql = UtilTool.buildSelectSql(featureCols, sourceTable, idCol);
            if (isSample) {
                dataset = sparkSession.sql(sql).sample(sampleRatio);
            } else {
                dataset = sparkSession.sql(sql);
            }
        } else {
            if (isSample) {
                String cacheTable = cacheUtil.modifyCacheTableName(sourceTable);
                if (cacheUtil.isCacheTableExists(cacheTable)) {
                    dataset = sparkSession.table(cacheTable)
                            .select(JavaConversions
                                    .asScalaBuffer(UtilTool.selectColumns(featureCols, idCol)));
                } else {
                    Dataset<Row> cacheRdd = UtilTool
                            .readFromGreenPlum(sparkSession, sourceTable, idCol, sampleNumber);
                    dataset = cacheRdd.select(
                            JavaConversions.asScalaBuffer(UtilTool.selectColumns(featureCols, idCol)));
                    if (!cacheUtil.cacheTableForDataset(cacheRdd, cacheTable)) {
                        logger.error("table = {} cache fail!!!", cacheTable);
                        return false;
                    }
                    logger.info("table = {} cache success", cacheTable);
                }
            } else {
                dataset = UtilTool.readFromGreenPlum(sparkSession, sourceTable, idCol)
                        .select(
                                JavaConversions.asScalaBuffer(UtilTool.selectColumns(featureCols, idCol)));
            }
        }

        VectorAssembler assembler = new VectorAssembler()
                .setHandleInvalid("skip")
                .setInputCols(featureCols)
                .setOutputCol("features");
        Dataset<Row> transDF = assembler.transform(dataset).select(idCol, "features")
                .persist(StorageLevel.MEMORY_AND_DISK());

        Map<String, DataType> schema = UtilTool.getDataTypeMaps(transDF.schema());
        PCAModel model = new PCA(k)
                .fit(transDF.toJavaRDD().map(row -> Vectors.fromML((DenseVector) row.get(1))));

        JavaRDD<Row> resultRdd = transDF.toJavaRDD().map(row -> {
            Vector vector = model.transform(Vectors.fromML((DenseVector) row.get(1)));
            List<Object> features = new ArrayList<>();
            features.add(row.get(0));
            for (Object e : vector.toArray()) {
                features.add(e);
            }
            return RowFactory.create(features.toArray());
        });

        StructType newSchema = UtilTool.createSchema(schema, Collections.singletonList(idCol), k);

        Dataset<Row> df = sparkSession.createDataFrame(resultRdd, newSchema);

        if (isHive) {
            this.saveHiveTable(df, targetTable);
        } else {
            UtilTool.saveGreenplumTable(df, targetTable);
        }

        modelPath = String.format(UtilTool.MODEL_PATH, algorithmName, uniqueKey);

        Map<String, Object> inputKV = new HashMap<>();
        List<String> outputTables = new ArrayList<>();
        outputTables.add(targetTable);

        String metaPath = String.format(UtilTool.RUN_META_PATH, algorithmName, uniqueKey);

        OutputResult outputResult = this.buildOutputResult(outputTables, 0, "", inputKV);

        if (isHive) {
            this.saveOutputResultForHive(sparkSession, outputResult, metaPath);
        } else {
            this.saveOutputResultForMysql(sparkSession, outputResult);
        }
        retResult = outputResult;
        transDF.unpersist();
        return true;
    }

}
